"""
探究了5个transformations
0: y
1: x
2: diagonal
3: cycle
4: rotation
5: random translation
移动类的transformations无明显差异，通过threshold可以区分rotation
"""
import argparse
from collections import defaultdict

import wandb
from torch import nn, optim
from tqdm import trange

from exps.models import *
from exps.patterns.data_generator import *
from utils.helpers import set_seed
from utils.visualize import *


# Default hyperparameters. Each key becomes a CLI flag of the same name
# whose argparse type is inferred from the default value's type.
hyperparameter_defaults = dict(
    batch_size=20,
    learning_rate=0.0005,
    epochs=1001,
    transformation=4,  # which transformation to train on (see module docstring)
    angle = 0,  # fixed rotation index used by the translation-style transformations
    loss='betaH',
    img_id=0,  # which pattern file to load from patterns/
    random_seed=224
)
parser = argparse.ArgumentParser()
# NOTE: type(value) is safe here because every default is int/float/str;
# it would misbehave for bool defaults (argparse's bool('False') is truthy).
for key, value in hyperparameter_defaults.items():
    parser.add_argument(f'--{key}', default=value, type=type(value))
args = parser.parse_args()
config = args  # the parsed namespace doubles as the run configuration

set_seed(config.random_seed)
wandb.init(project="exps", config=config, group='transformation', tags=['impedance'])

def train(model, epochs, log=True):
    """Fit ``model`` to reconstruct the global ``dataset`` from a fixed 1-D latent grid.

    Each image in ``dataset`` is assigned one scalar latent coordinate on an
    evenly spaced grid in [-2, 2]; every epoch is a single full-batch gradient
    step minimising pixel-wise binary cross-entropy between the decoded images
    and the dataset.

    Parameters
    ----------
    model : nn.Module
        Decoder mapping an (N, 1) latent tensor to (N, 1, H, W) reconstructions
        with values in [0, 1] (required by binary cross-entropy).
    epochs : int
        Number of full-batch gradient steps.
    log : bool, default True
        If True, push per-epoch metrics (and, every 100 epochs, reconstruction
        images and a latent-traversal figure) to wandb.

    Returns
    -------
    dict
        The metric store of the final epoch (mean loss, epoch index, and
        possibly a traversal figure).
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    # One evenly spaced latent coordinate per training image.
    z = torch.linspace(-2, 2, len(dataset)).cuda()
    # Hoisted out of the loop: the dataset never changes, so transfer it to
    # the GPU once instead of once per epoch.
    img = dataset.cuda()
    for e in trange(epochs):
        storer = defaultdict(list)

        recon_batch = model(z.unsqueeze(1))
        loss = F.binary_cross_entropy(recon_batch, img)
        opt.zero_grad()
        loss.backward()
        opt.step()
        storer['loss'].append(loss.item())

        # Collapse per-step metric lists into scalar means for logging.
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)

        if e % 100 == 0:
            wandb.log({
                'recon': wandb.Image(recon_batch[0, 0]),
                'img': wandb.Image(img[0, 0])
            })
            # plt_sample_traversal expects an object exposing `.decoder`, so
            # wrap the model in a throwaway Module; run on CPU in eval mode,
            # then restore GPU training mode afterwards.
            model.cpu()
            model.eval()
            m = nn.Module()
            m.decoder = model
            fig = plt_sample_traversal(None, m, 7, range(latent_dim), r=2)
            storer['traversal'] = fig
            model.cuda()
            model.train()
            plt.close()
        storer['epoch'] = e
        if log:
            wandb.log(storer, sync=False)
    return storer


# generate data
# Build the training set: 40 images (each 1x64x64) sampled from the rotated
# pattern grid according to `config.transformation`.
img_id = config.img_id
epochs = config.epochs

img = torch.load(f'patterns/{img_id}.pat')
# imgs appears to be indexed as [rotation_angle, pos_a, pos_b] over a 40x40
# translation grid — gen_rotation lives in exps.patterns.data_generator;
# TODO confirm the index order against that module.
imgs = gen_rotation(img, img.shape)[0]
angle = config.angle
if config.transformation==0:
    # 0: translate along the first spatial index at fixed second index = 0
    dataset= imgs[angle,:,0].reshape(-1,1,64,64)
elif config.transformation==1:
    # 1: translate along the second spatial index at fixed first index = 0
    dataset = imgs[angle, 0, :].reshape(-1, 1, 64, 64)
elif config.transformation == 2:
    # 2: move along the diagonal (both position indices equal)
    dataset = [imgs[angle, i, i] for i in range(40)]
    dataset = torch.stack(dataset).reshape(-1, 1, 64, 64)
elif config.transformation == 3:
    # 3: "cycle" — 40 points on a half-circle arc of radius 20, parameterised
    # from pi/2 down to just past -pi/2 (the +0.001 keeps the endpoint in range).
    dataset= []
    for t in np.linspace(np.pi/2,-np.pi/2+0.001,40):
        x = 20*np.cos(t)
        y = 20 - 20 * np.sin(t)
        dataset.append(imgs[angle,int(x),int(y)])
    dataset = torch.stack(dataset).reshape(-1, 1, 64, 64)
elif config.transformation == 4:
    # 4: vary the rotation angle at a fixed position (0, 0)
    dataset = imgs[:, 0, 0].reshape(-1, 1, 64, 64)
elif config.transformation == 5:
    # 5: random translation — 40 distinct positions sampled without
    # replacement from the 40x40 grid, at the fixed angle
    dataset = []
    for t in np.random.permutation(40*40)[:40]:
        x =int(t//40)
        y = int(t%40)
        dataset.append(imgs[angle,y,x])
    dataset = torch.stack(dataset).reshape(-1, 1, 64, 64)
else:
    raise  NotImplementedError

# Decoder architecture: scalar latent -> linear projection to a 4x4 feature
# map -> three upsampling DecoderUnits -> final transposed convolution down
# to one channel, squashed into [0, 1] by a sigmoid (64x64 output).
latent_dim = 1
hid_channels = 16
hidden_dim = 128
cnn_kwargs = dict(stride=2, padding=1, kernel_size=4)

layers = [
    nn.Linear(latent_dim, hid_channels * 16),
    nn.LeakyReLU(),
    Reshape((-1, hid_channels, 4, 4)),
]
for _ in range(3):
    layers.append(DecoderUnit(in_channels=hid_channels,
                              out_channels=hid_channels, **cnn_kwargs))
layers.append(nn.ConvTranspose2d(in_channels=hid_channels,
                                 out_channels=1, **cnn_kwargs))
layers.append(nn.Sigmoid())
decoder = nn.Sequential(*layers)

decoder.cuda()
decoder.train()


iterations = epochs
storer = train(decoder, epochs)
plt.show()